{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ndo4ERqnwQOU"
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "MTKwbguKwT4R"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xfNT-mlFwxVM"
      },
      "source": [
        "# 畳み込み変分オートエンコーダ"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0TD5ZrvEMbhZ"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td><a target=\"_blank\" href=\"https://www.tensorflow.org/tutorials/generative/cvae\">     <img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\">     TensorFlow.org で表示</a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/ja/tutorials/generative/cvae.ipynb\">     <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\">     Google Colab で実行</a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://github.com/tensorflow/docs-l10n/blob/master/site/ja/tutorials/generative/cvae.ipynb\">     <img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\">      GitHub でソースを表示</a></td>\n",
        "  <td><a href=\"https://storage.googleapis.com/tensorflow_docs/docs-l10n/site/ja/tutorials/generative/cvae.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\">ノートブックをダウンロード</a></td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ITZuApL56Mny"
      },
      "source": [
        "このノートブックでは、MNIST データセットで変分オートエンコーダ（VAE）（[1](https://arxiv.org/abs/1312.6114)、[2](https://arxiv.org/abs/1401.4082)）のトレーニング方法を実演します。VAE はオートエンコードの確率論的見解で、高次元入力データをより小さな表現に圧縮するモデルです。入力を潜在ベクトルにマッピングする従来のオートエンコーダとは異なり、VAE は入力を、ガウスの平均と分散といった、確率分布のパラメータにマッピングします。このアプローチによって、画像生成に役立つ、構造化された連続潜在空間が生成されます。\n",
        "\n",
        "![CVAE image latent space](images/cvae_latent_space.jpg)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e1_Y75QXJS6h"
      },
      "source": [
        "## TensorFlow とその他のライブラリをインポートする"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "P-JuIu2N_SQf"
      },
      "outputs": [],
      "source": [
        "!pip install tensorflow-probability\n",
        "\n",
        "# to generate gifs\n",
        "!pip install imageio\n",
        "!pip install git+https://github.com/tensorflow/docs"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YfIk2es3hJEd"
      },
      "outputs": [],
      "source": [
        "from IPython import display\n",
        "\n",
        "import glob\n",
        "import imageio\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import PIL\n",
        "import tensorflow as tf\n",
        "import tensorflow_probability as tfp\n",
        "import time"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iYn4MdZnKCey"
      },
      "source": [
        "## MNIST データセットを読み込む\n",
        "\n",
        "それぞれの MNIST 画像は、もともと 784 個の整数値から成るベクトルで、各整数値は、ピクセルの強度を表す 0～255 の値です。各ピクセルをベルヌーイ分布を使ってモデル化し、データセットを統計的にバイナリ化します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "a4fYMGxGhrna"
      },
      "outputs": [],
      "source": [
        "(train_images, _), (test_images, _) = tf.keras.datasets.mnist.load_data()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NFC2ghIdiZYE"
      },
      "outputs": [],
      "source": [
        "def preprocess_images(images):\n",
        "  images = images.reshape((images.shape[0], 28, 28, 1)) / 255.\n",
        "  return np.where(images &gt; .5, 1.0, 0.0).astype('float32')\n",
        "\n",
        "train_images = preprocess_images(train_images)\n",
        "test_images = preprocess_images(test_images)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "S4PIDhoDLbsZ"
      },
      "outputs": [],
      "source": [
        "train_size = 60000\n",
        "batch_size = 32\n",
        "test_size = 10000"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PIGN6ouoQxt3"
      },
      "source": [
        "## *tf.data* を使用して、データをバッチ化・シャッフルする"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-yKCCQOoJ7cn"
      },
      "outputs": [],
      "source": [
        "train_dataset = (tf.data.Dataset.from_tensor_slices(train_images)\n",
        "                 .shuffle(train_size).batch(batch_size))\n",
        "test_dataset = (tf.data.Dataset.from_tensor_slices(test_images)\n",
        "                .shuffle(test_size).batch(batch_size))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "THY-sZMiQ4UV"
      },
      "source": [
        "## *tf.keras.Sequential* を使ってエンコーダとデコーダネットワークを定義する\n",
        "\n",
        "この VAE の例では、エンコーダとデコーダのネットワークに 2 つの小さな ConvNets を使用しています。文献では、これらのネットワークはそれぞれ、推論/認識モデルおよび生成モデルとも呼ばれています。実装を簡略化するために `tf.keras.Sequential` を使用しています。以降の説明では、観測と潜在変数をそれぞれ $x$ と $z$ で表記しています。\n",
        "\n",
        "### エンコーダネットワーク\n",
        "\n",
        "これは、おおよその事後分布 $q(z|x)$ を定義します。これは、観測を入力として受け取り、潜在表現 $z$ の条件付き分布を指定するための一連のパラメータを出力します。この例では、単純に対角ガウスとして分布をモデル化するため、ネットワークは、素因数分解されたガウスの平均と対数分散を出力します。数値的な安定性を得るために、直接分散を出力する代わりに、対数分散を出力します。\n",
        "\n",
        "### デコーダネットワーク\n",
        "\n",
        "これは、観測 $p(x|z)$ の条件付き分布を定義します。これは入力として潜在サンプル $z$ を取り、観測の条件付き分布のパラメータを出力します。$p(z)$ 前の潜在分布を単位ガウスとしてモデル化します。\n",
        "\n",
        "### パラメータ再設定のコツ\n",
        "\n",
        "トレーニング中にデコーダのサンプル $z$ を生成するには、エンコーダが出力したパラメータによって定義される潜在分布から入力観測 $x$ を指定してサンプリングできます。ただし、このサンプリング演算では、バックプロパゲーションがランダムなノードを通過できないため、ボトルネックが発生します。\n",
        "\n",
        "これを解消するために、パラメータ再設定のコツを使用します。この例では、次のように、デコーダパラメータともう 1 つの $\\epsilon$ を使用して $z$ を近似化します。\n",
        "\n",
        "$$z = \\mu + \\sigma \\odot \\epsilon$$\n",
        "\n",
        "上記の $\\mu$ と $\\sigma$ は、それぞれガウス分布の平均と標準偏差を表します。これらは、デコーダ出力から得ることができます。$\\epsilon$ は、$z$ の偶然性を維持するためのランダムノイズとして捉えることができます。$\\epsilon$ は標準正規分布から生成します。\n",
        "\n",
        "潜在変数 $z$ は、$\\mu$、$\\sigma$、および $\\epsilon$ の関数によって生成されるようになりました。これらによって、モデルがそれぞれ $\\mu$ と $\\sigma$ を介してエンコーダの勾配をバックプロパゲートしながら、$\\epsilon$ を介して偶然性を維持できるようになります。\n",
        "\n",
        "### ネットワークアーキテクチャ\n",
        "\n",
        "エンコーダネットワークの場合、2 つの畳み込みレイヤーを使用し、その後に全結合レイヤーが続きます。デコーダネットワークでは、3 つの畳み込み転置レイヤー（一部の文脈ではデコンボリューションレイヤーとも呼ばれる）が続く全結合レイヤーを使用することで、このアーキテクチャをミラーリングします。VAE を使用する場合にはバッチの正規化を回避するのが一般的であることに注意してください。これは、ミニバッチを使用することで追加される偶然性によって、サンプリングの偶然性にさらに不安定性を加える可能性があるためです。\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VGLbvBEmjK0a"
      },
      "outputs": [],
      "source": [
        "class CVAE(tf.keras.Model):\n",
        "  \"\"\"Convolutional variational autoencoder.\"\"\"\n",
        "\n",
        "  def __init__(self, latent_dim):\n",
        "    super(CVAE, self).__init__()\n",
        "    self.latent_dim = latent_dim\n",
        "    self.encoder = tf.keras.Sequential(\n",
        "        [\n",
        "            tf.keras.layers.InputLayer(input_shape=(28, 28, 1)),\n",
        "            tf.keras.layers.Conv2D(\n",
        "                filters=32, kernel_size=3, strides=(2, 2), activation='relu'),\n",
        "            tf.keras.layers.Conv2D(\n",
        "                filters=64, kernel_size=3, strides=(2, 2), activation='relu'),\n",
        "            tf.keras.layers.Flatten(),\n",
        "            # No activation\n",
        "            tf.keras.layers.Dense(latent_dim + latent_dim),\n",
        "        ]\n",
        "    )\n",
        "\n",
        "    self.decoder = tf.keras.Sequential(\n",
        "        [\n",
        "            tf.keras.layers.InputLayer(input_shape=(latent_dim,)),\n",
        "            tf.keras.layers.Dense(units=7*7*32, activation=tf.nn.relu),\n",
        "            tf.keras.layers.Reshape(target_shape=(7, 7, 32)),\n",
        "            tf.keras.layers.Conv2DTranspose(\n",
        "                filters=64, kernel_size=3, strides=2, padding='same',\n",
        "                activation='relu'),\n",
        "            tf.keras.layers.Conv2DTranspose(\n",
        "                filters=32, kernel_size=3, strides=2, padding='same',\n",
        "                activation='relu'),\n",
        "            # No activation\n",
        "            tf.keras.layers.Conv2DTranspose(\n",
        "                filters=1, kernel_size=3, strides=1, padding='same'),\n",
        "        ]\n",
        "    )\n",
        "\n",
        "  @tf.function\n",
        "  def sample(self, eps=None):\n",
        "    if eps is None:\n",
        "      eps = tf.random.normal(shape=(100, self.latent_dim))\n",
        "    return self.decode(eps, apply_sigmoid=True)\n",
        "\n",
        "  def encode(self, x):\n",
        "    mean, logvar = tf.split(self.encoder(x), num_or_size_splits=2, axis=1)\n",
        "    return mean, logvar\n",
        "\n",
        "  def reparameterize(self, mean, logvar):\n",
        "    eps = tf.random.normal(shape=mean.shape)\n",
        "    return eps * tf.exp(logvar * .5) + mean\n",
        "\n",
        "  def decode(self, z, apply_sigmoid=False):\n",
        "    logits = self.decoder(z)\n",
        "    if apply_sigmoid:\n",
        "      probs = tf.sigmoid(logits)\n",
        "      return probs\n",
        "    return logits"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0FMYgY_mPfTi"
      },
      "source": [
        "## 損失関数とオプティマイザを定義する\n",
        "\n",
        "VAE は、限界対数尤度の証拠下限（ELBO）を最大化することでトレーニングします。\n",
        "\n",
        "$$\\log p(x) \\ge \\text{ELBO} = \\mathbb{E}_{q(z|x)}\\left[\\log \\frac{p(x, z)}{q(z|x)}\\right].$$\n",
        "\n",
        "実際には、この期待値の単一サンプルのモンテカルロ推定を最適化します。\n",
        "\n",
        "$$\\log p(x| z) + \\log p(z) - \\log q(z|x),$$ とし、$z$ は $q(z|x)$ からサンプリングされます。\n",
        "\n",
        "**注意**: KL 項を分析的に計算することもできますが、ここでは簡単にするために、3 つの項すべてをモンテカルロ Estimator に組み込んでいます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iWCn_PVdEJZ7"
      },
      "outputs": [],
      "source": [
        "optimizer = tf.keras.optimizers.Adam(1e-4)\n",
        "\n",
        "\n",
        "def log_normal_pdf(sample, mean, logvar, raxis=1):\n",
        "  log2pi = tf.math.log(2. * np.pi)\n",
        "  return tf.reduce_sum(\n",
        "      -.5 * ((sample - mean) ** 2. * tf.exp(-logvar) + logvar + log2pi),\n",
        "      axis=raxis)\n",
        "\n",
        "\n",
        "def compute_loss(model, x):\n",
        "  mean, logvar = model.encode(x)\n",
        "  z = model.reparameterize(mean, logvar)\n",
        "  x_logit = model.decode(z)\n",
        "  cross_ent = tf.nn.sigmoid_cross_entropy_with_logits(logits=x_logit, labels=x)\n",
        "  logpx_z = -tf.reduce_sum(cross_ent, axis=[1, 2, 3])\n",
        "  logpz = log_normal_pdf(z, 0., 0.)\n",
        "  logqz_x = log_normal_pdf(z, mean, logvar)\n",
        "  return -tf.reduce_mean(logpx_z + logpz - logqz_x)\n",
        "\n",
        "\n",
        "@tf.function\n",
        "def train_step(model, x, optimizer):\n",
        "  \"\"\"Executes one training step and returns the loss.\n",
        "\n",
        "  This function computes the loss and gradients, and uses the latter to\n",
        "  update the model's parameters.\n",
        "  \"\"\"\n",
        "  with tf.GradientTape() as tape:\n",
        "    loss = compute_loss(model, x)\n",
        "  gradients = tape.gradient(loss, model.trainable_variables)\n",
        "  optimizer.apply_gradients(zip(gradients, model.trainable_variables))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Rw1fkAczTQYh"
      },
      "source": [
        "## トレーニング\n",
        "\n",
        "- データセットのイテレーションから始めます。\n",
        "- 各イテレーションで、画像をエンコーダに渡して、おおよその事後分布 $q(z|x)$ 一連の平均値と対数分散パラメータを取得します。\n",
        "- 次に、$q(z|x)$ から得たサンプルに*パラメータ再設定のコツ*を適用します。\n",
        "- 最後に、パラメータを再設定したサンプルをデコーダに渡して、生成分布 $p(x|z)$ のロジットを取得します。\n",
        "- **注意:** トレーニングセットに 60k のデータポイントとテストセットに 10k のデータポイント持つ Keras で読み込んだデータセットを使用しているため、テストセットの ELBO は、Larochelle の MNIST の動的なバイナリ化を使用する、文献で報告された結果よりもわずかに高くなります。\n",
        "\n",
        "### 画像の生成\n",
        "\n",
        "- トレーニングの後は、画像をいくつか生成します。\n",
        "- 分布 $p(z)$ 前の単位ガウスから一連の潜在ベクトルをサンプリングすることから始めます。\n",
        "- すると、ジェネレータはその潜在サンプル $z$ を観測のロジットに変換し、分布 $p(x|z)$ が得られます。\n",
        "- ここで、ベルヌーイ分布の確率を図に作成します。\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NS2GWywBbAWo"
      },
      "outputs": [],
      "source": [
        "epochs = 10\n",
        "# set the dimensionality of the latent space to a plane for visualization later\n",
        "latent_dim = 2\n",
        "num_examples_to_generate = 16\n",
        "\n",
        "# keeping the random vector constant for generation (prediction) so\n",
        "# it will be easier to see the improvement.\n",
        "random_vector_for_generation = tf.random.normal(\n",
        "    shape=[num_examples_to_generate, latent_dim])\n",
        "model = CVAE(latent_dim)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RmdVsmvhPxyy"
      },
      "outputs": [],
      "source": [
        "def generate_and_save_images(model, epoch, test_sample):\n",
        "  mean, logvar = model.encode(test_sample)\n",
        "  z = model.reparameterize(mean, logvar)\n",
        "  predictions = model.sample(z)\n",
        "  fig = plt.figure(figsize=(4, 4))\n",
        "\n",
        "  for i in range(predictions.shape[0]):\n",
        "    plt.subplot(4, 4, i + 1)\n",
        "    plt.imshow(predictions[i, :, :, 0], cmap='gray')\n",
        "    plt.axis('off')\n",
        "\n",
        "  # tight_layout minimizes the overlap between 2 sub-plots\n",
        "  plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))\n",
        "  plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "swCyrbqQQ-Ri"
      },
      "outputs": [],
      "source": [
        "# Pick a sample of the test set for generating output images\n",
        "assert batch_size &gt;= num_examples_to_generate\n",
        "for test_batch in test_dataset.take(1):\n",
        "  test_sample = test_batch[0:num_examples_to_generate, :, :, :]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2M7LmLtGEMQJ"
      },
      "outputs": [],
      "source": [
        "generate_and_save_images(model, 0, test_sample)\n",
        "\n",
        "for epoch in range(1, epochs + 1):\n",
        "  start_time = time.time()\n",
        "  for train_x in train_dataset:\n",
        "    train_step(model, train_x, optimizer)\n",
        "  end_time = time.time()\n",
        "\n",
        "  loss = tf.keras.metrics.Mean()\n",
        "  for test_x in test_dataset:\n",
        "    loss(compute_loss(model, test_x))\n",
        "  elbo = -loss.result()\n",
        "  display.clear_output(wait=False)\n",
        "  print('Epoch: {}, Test set ELBO: {}, time elapse for current epoch: {}'\n",
        "        .format(epoch, elbo, end_time - start_time))\n",
        "  generate_and_save_images(model, epoch, test_sample)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "P4M_vIbUi7c0"
      },
      "source": [
        "### 最後のトレーニングエポックから生成された画像を表示する"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WfO5wCdclHGL"
      },
      "outputs": [],
      "source": [
        "def display_image(epoch_no):\n",
        "  return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5x3q9_Oe5q0A"
      },
      "outputs": [],
      "source": [
        "plt.imshow(display_image(epoch))\n",
        "plt.axis('off')  # Display images"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NywiH3nL8guF"
      },
      "source": [
        "### 保存したすべての画像のアニメーション GIF を表示する"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IGKQgENQ8lEI"
      },
      "outputs": [],
      "source": [
        "anim_file = 'cvae.gif'\n",
        "\n",
        "with imageio.get_writer(anim_file, mode='I') as writer:\n",
        "  filenames = glob.glob('image*.png')\n",
        "  filenames = sorted(filenames)\n",
        "  last = -1\n",
        "  for i, filename in enumerate(filenames):\n",
        "    frame = 2*(i**0.5)\n",
        "    if round(frame) &gt; round(last):\n",
        "      last = frame\n",
        "    else:\n",
        "      continue\n",
        "    image = imageio.imread(filename)\n",
        "    writer.append_data(image)\n",
        "  image = imageio.imread(filename)\n",
        "  writer.append_data(image)\n",
        "\n",
        "import IPython\n",
        "if IPython.version_info &gt;= (6, 2, 0, ''):\n",
        "  display.Image(filename=anim_file)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2ZqAEtdqUmJF"
      },
      "outputs": [],
      "source": [
        "import tensorflow_docs.vis.embed as embed\n",
        "embed.embed_file(anim_file)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PeunRU6TSumT"
      },
      "source": [
        "### 潜在空間から数字の 2D 多様体を表示する\n",
        "\n",
        "次のコードを実行すると、各数字が 2D 潜在空間で別の数字に変化する、さまざまな数字クラスの連続分布が表示されます。潜在空間の標準正規分布の生成には、[TensorFlow Probability](https://www.tensorflow.org/probability) を使用します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "code",
        "id": "mNcaaYPBS3mj"
      },
      "outputs": [],
      "source": [
        "def plot_latent_images(model, n, digit_size=28):\n",
        "  \"\"\"Plots n x n digit images decoded from the latent space.\"\"\"\n",
        "\n",
        "  norm = tfp.distributions.Normal(0, 1)\n",
        "  grid_x = norm.quantile(np.linspace(0.05, 0.95, n))\n",
        "  grid_y = norm.quantile(np.linspace(0.05, 0.95, n))\n",
        "  image_width = digit_size*n\n",
        "  image_height = image_width\n",
        "  image = np.zeros((image_height, image_width))\n",
        "\n",
        "  for i, yi in enumerate(grid_x):\n",
        "    for j, xi in enumerate(grid_y):\n",
        "      z = np.array([[xi, yi]])\n",
        "      x_decoded = model.sample(z)\n",
        "      digit = tf.reshape(x_decoded[0], (digit_size, digit_size))\n",
        "      image[i * digit_size: (i + 1) * digit_size,\n",
        "            j * digit_size: (j + 1) * digit_size] = digit.numpy()\n",
        "\n",
        "  plt.figure(figsize=(10, 10))\n",
        "  plt.imshow(image, cmap='Greys_r')\n",
        "  plt.axis('Off')\n",
        "  plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "F-ZG69QCZnGY"
      },
      "outputs": [],
      "source": [
        "plot_latent_images(model, 20)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HrJRef8Ln945"
      },
      "source": [
        "## 次のステップ\n",
        "\n",
        "このチュートリアルでは、TensorFlow を使用して畳み込み変分オートエンコーダを実装する方法を実演しました。\n",
        "\n",
        "次のステップとして、ネットワークサイズを増加し、モデル出力の改善を試みることができます。たとえば、`Conv2D` と `Conv2DTranspose` の各レイヤーの `filter` おパラメータを 512 に設定することができます。最終的な 2D 潜在画像プロットを生成するには、`latent_dim` を 2 に維持することが必要であることに注意してください。また、ネットワークサイズが高まると、トレーニング時間も増大します。\n",
        "\n",
        "また、CIFAR-10 などのほかのデータセットを使って VAE を実装してみるのもよいでしょう。\n",
        "\n",
        "VAE は、さまざまなスタイルや複雑さの異なるスタイルで実装することができます。その他の実装については、次のリンクをご覧ください。\n",
        "\n",
        "- [変分オートエンコーダ (keras.io)](https://keras.io/examples/generative/vae/)\n",
        "- [「カスタムレイヤーとモデルを記述する」ガイドの VAE サンプル（tensorflow.org）](https://www.tensorflow.org/guide/keras/custom_layers_and_models#putting_it_all_together_an_end-to-end_example)\n",
        "- [TFP 確率レイヤー: 変分オートエンコーダ](https://www.tensorflow.org/probability/examples/Probabilistic_Layers_VAE)\n",
        "\n",
        "VAE をさらに詳しく学習するには、「[An Introduction to Variational Autoencoders](https://arxiv.org/abs/1906.02691)」をご覧ください。"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "cvae.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
